#!/usr/bin/env python3

"""
    Portchannel member ping: get portchannel member packet loss state.
"""

import argparse
import ipaddress
import random
import select
import socket
import sys
import time

from swsscommon.swsscommon import ConfigDBConnector
from scapy.all import *


# Maximum number of bytes read from a socket by a single recv() call.
RCV_SIZE = 20480
# Timeout set on every raw socket with settimeout() (seconds).
RCV_TIMEOUT = 10000
# Total time receive() may wait for one probe; used as the select() timeout (seconds).
SELECT_TIMEOUT = 2
# Target length in bytes of the inner probe packet; its payload is padded up to this size.
PROBE_PACKET_LEN = 9000
# Used as the DSCP value of the probe packets (shifted left by 2 to form the TOS byte).
# NOTE(review): the name says DHCP but it is only used as a DSCP value - presumably a typo.
DHCP_DEFAULT = 0

def receive(my_socket, timeout, exp_pkt, time_sent):
    """Wait until the expected packet shows up on one of the given sockets.

    Args:
        my_socket: List of sockets to watch; handed to ``select.select`` as
            the read list.
        timeout: Maximum total time to wait, in seconds.
        exp_pkt: Expected packet. It is serialized with ``bytes()`` (so a
            scapy packet or a bytes object both work) and compared against
            every received frame with its first 14 bytes (the Ethernet
            header) stripped.
        time_sent: ``time.time()`` timestamp taken just before the probe
            was sent; used to compute the delay.

    Returns:
        ``(True, delay_ms, sock)`` when a matching frame was read from
        ``sock``, otherwise ``(False, 0, None)`` once the timeout expired.
    """
    # Serialize the expected packet once instead of once per received frame,
    # and compare raw bytes rather than the str() rendering of bytes.
    expected = bytes(exp_pkt)
    time_left = timeout
    delay = 0
    while True:
        started_select = time.time()
        ready, _, _ = select.select(my_socket, [], [], time_left)
        if not ready:  # Timeout
            return False, delay, None
        time_received = time.time()

        for sock in ready:
            frame = sock.recv(RCV_SIZE)
            # exclude ethernet header
            if frame[14:] == expected:
                delay = (time_received - time_sent) * 1000
                return True, delay, sock

        # Only unrelated traffic arrived: keep waiting for the rest of the budget.
        time_left = time_left - (time_received - started_select)
        if time_left <= 0:
            return False, delay, None
        
def find_probe_packet(interface, dst_out, dst_in, sockets, exp_socket, max_iter):
    """Search for an IP-in-IP probe packet that comes back through `interface`.

    Each attempt builds an inner packet with a random source address and
    `dst_in` as destination, wraps it in an outer IP header addressed to
    `dst_out`, sends it out of `interface` and waits for the inner packet
    (with TTL 63) on any of `sockets`. The attempt succeeds when the packet
    is received on `exp_socket`.

    Args:
        interface: Portchannel member interface the probe is sent from.
        dst_out: Destination of the outer header (the neighbor decap IP).
        dst_in: Destination of the inner packet; also used as the source of
            the outer header.
        sockets: Sockets to listen on, one per portchannel member.
        exp_socket: The socket on which the probe is expected to arrive.
        max_iter: Maximum number of attempts.

    Returns:
        ``(pkt, exp_pkt)``: the packet to send and the packet expected back.

    Exits the process with status 1 (after printing a diagnostic) when no
    suitable probe packet is found within `max_iter` attempts.
    """
    tos = DHCP_DEFAULT << 2
    src_out = dst_in
    local_ip = ipaddress.IPv4Address(dst_in)
    decap_valid = False
    for attempt in range(max_iter):
        print("---------- Attempting to create probe packet that goes through %s, iteration: %d ---------" % (interface, attempt))

        # Draw random source addresses until one is usable: not our own
        # address and not multicast / private / reserved.
        while True:
            ip_src = ipaddress.IPv4Address(random.randint(1, 0xffffffff))
            if not (ip_src == local_ip or ip_src.is_multicast
                    or ip_src.is_private or ip_src.is_reserved):
                break

        # scapy expects the address as a string, not an IPv4Address object.
        header = IP(src=str(ip_src), dst=dst_in, tos=tos, ttl=64, id=1, ihl=None, proto=4)/ICMP()
        # Pad with raw bytes: a str payload would be UTF-8 encoded by scapy,
        # turning every value >= 128 into two bytes and overshooting
        # PROBE_PACKET_LEN.
        padding = bytes(i % 256 for i in range(PROBE_PACKET_LEN - len(header)))
        inner_pkt = header/padding
        pkt = IP(src=src_out, dst=dst_out, tos=tos)/inner_pkt

        # The inner packet is expected back with its TTL decremented once
        # (presumably by the neighbor after decapsulation).
        exp_pkt = inner_pkt.copy()
        exp_pkt['IP'].ttl = 63

        time_sent = time.time()
        send(pkt, iface=interface, verbose=False)
        have_received, _, recv_socket = receive(sockets, SELECT_TIMEOUT, exp_pkt, time_sent)
        if have_received and recv_socket == exp_socket:
            return pkt, exp_pkt
        if have_received:
            # The probe came back, just not on the member under test.
            decap_valid = True

    if decap_valid:
        print("Decap works properly. link %s is unreachable, Please check!\n" % interface)
    else:
        print("Decap using %s doesn't work properly, Please Check!\n" % (dst_out))
    # Report the failure through the exit status, like the other error paths.
    sys.exit(1)

def get_portchannel(interface):
    """Look up the portchannel that `interface` is a member of.

    Reads the ``PORTCHANNEL`` table from ConfigDB and returns the first entry
    whose ``members`` field contains `interface`.

    Args:
        interface: Name of the member interface.

    Returns:
        ``(name, attrs)``: the portchannel key and its table entry.

    Exits the process with status 1 when no portchannel lists `interface`.
    """
    configdb = ConfigDBConnector()
    configdb.connect()
    portchannels = configdb.get_table("PORTCHANNEL")
    for name, attrs in portchannels.items():
        # An entry without a 'members' field simply has no members; it must
        # not abort the search with a KeyError.
        if interface in attrs.get('members', ()):
            return name, attrs
    sys.stderr.write("Interface %s is not a portchannel member. Please Check!\n" % interface)
    sys.exit(1)

def is_ip_prefix_in_key(key):
    """Tell whether a table key carries an IP prefix.

    Keys that include an IP prefix are tuples; keys without one are plain
    strings.
    """
    has_prefix = isinstance(key, tuple)
    return has_prefix

def get_portchannel_ipv4(portchannel_name):
    """Return the IPv4 address configured on a portchannel.

    Scans the ``PORTCHANNEL_INTERFACE`` table of ConfigDB for
    ``(portchannel, prefix)`` keys and returns the address part of the first
    IPv4 prefix that belongs to `portchannel_name`.

    Args:
        portchannel_name: Name of the portchannel.

    Returns:
        The IPv4 address as a string, without the prefix length.

    Exits the process with status 1 when the portchannel has no IPv4 address.
    """
    configdb = ConfigDBConnector()
    configdb.connect()
    config = configdb.get_config()
    # A missing table means no portchannel has an address; fall through to
    # the regular error message instead of raising KeyError.
    portchannel_interfaces = config.get("PORTCHANNEL_INTERFACE", {})
    for key in portchannel_interfaces:
        if not is_ip_prefix_in_key(key):
            continue
        pc, prefix = key
        ip = prefix.split("/")[0]
        if pc == portchannel_name and ipaddress.ip_address(ip).version == 4:
            return ip
    # The original referenced an undefined name here and died with NameError.
    sys.stderr.write("Portchannel %s doesn't have IPV4 address. Please Check!\n" % portchannel_name)
    sys.exit(1)
   
def create_socket(interface, portchannel):
    """Open one raw receive socket per portchannel member.

    Args:
        interface: Member interface under test; its socket is returned
            separately as the expected receive socket.
        portchannel: Portchannel table entry; its ``members`` field lists the
            member interface names.

    Returns:
        ``(sockets, exp_socket)``: all member sockets, and the one bound to
        `interface`.

    Exits the process with status 1 when a socket cannot be created or when
    `interface` is not among the members.
    """
    sockets = []
    exp_socket = None
    try:
        for iface in portchannel['members']:
            s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(ETH_P_IP))
            # Track the socket before bind() so it is closed if bind() fails.
            sockets.append(s)
            s.bind((iface, 0))
            s.settimeout(RCV_TIMEOUT)
            if iface == interface:
                exp_socket = s
    except Exception:
        for s in sockets:
            s.close()
        sys.stderr.write("Unable to create socket. Check your permissions\n")
        sys.exit(1)

    if exp_socket is None:
        # Without this check the return below would hit an unbound local.
        for s in sockets:
            s.close()
        sys.stderr.write("Interface %s is not a portchannel member. Please Check!\n" % interface)
        sys.exit(1)
    return sockets, exp_socket

def print_statistics(interface, total_count, recv_count):
    """Print a ping-style summary of transmitted/received probes.

    Args:
        interface: Interface name shown in the header line.
        total_count: Number of probes sent.
        recv_count: Number of probes received back.
    """
    # Nothing sent (e.g. interrupted before the first probe completed):
    # report 0% loss instead of dividing by zero.
    if total_count:
        loss = (total_count - recv_count) * 1.0 / total_count * 100
    else:
        loss = 0.0
    print("\n--- %s ping statistics ---" % interface)
    print("%d packets transmitted, %d received, %0.2f%% packet loss" % (total_count, recv_count, loss))

def main():
    """Command-line entry point.

    Parses the arguments, resolves the portchannel of the given member
    interface and its IPv4 address, opens one socket per member, searches
    for a probe packet that returns through the member under test, and then
    sends that probe once per second, printing per-probe results and a
    final summary. With the default count of 0 it runs until interrupted.
    """
    parser = argparse.ArgumentParser(description='Check portchannel member link state',
                                      formatter_class=argparse.RawTextHelpFormatter,
                                      epilog=""" 
    Example:
    pcmping -i Ethernet8 -ip 10.10.10.10
    pcmping -i Ethernet8 -ip 10.10.10.10 -c 5
    pcmping -i Ethernet8 -ip 10.10.10.10 -c 5 -t 20
"""
    )

    parser.add_argument('-i', '--interface', type=str, required=True, help='Portchannel member interface')
    parser.add_argument('-ip','--decapip', type=str, required=True, help='Neighbor decap ip (usually same as neighbor loopback ip)')
    parser.add_argument('-c', '--count', type=int, help='Number of decap packets to be sent.', default=0)
    parser.add_argument('-t', '--trial', type=int, help='Maximum attempt times to create probe packet.\nPlease increase the value as the number of portchannel members increases', default=20)
    parser.add_argument('-v', '--version', action='version', version='%(prog)s 1.0')
    args = parser.parse_args()

    interface = args.interface
    decap_ip = args.decapip
    exp_count = args.count
    max_trial = args.trial

    portchannel_name, portchannel = get_portchannel(interface)
    portchannel_ip = get_portchannel_ipv4(portchannel_name)
    sockets, exp_socket = create_socket(interface, portchannel)
    recv_count = 0
    total_count = 0
    try:
        pkt, exp_pkt = find_probe_packet(interface, decap_ip, portchannel_ip, sockets, exp_socket, max_trial)
        print("Preparation done! Start probing ............")
        print("PORTCHANNEL Member PING %d bytes of data. PORTCHANNEL: %s, INTERFACE: %s" % (PROBE_PACKET_LEN, portchannel_name, interface))
        try:
            while True:
                time_sent = time.time()
                send(pkt, iface=interface, verbose=False)
                # NOTE(review): a reply on any member socket counts as
                # received here, whereas find_probe_packet insists on
                # exp_socket - confirm this is intended.
                have_received, delay, _ = receive(sockets, SELECT_TIMEOUT, exp_pkt, time_sent)
                total_count += 1
                if have_received:
                    print("%d bytes from %s: time=%0.4fms" % (PROBE_PACKET_LEN, interface, delay))
                    recv_count += 1
                else:
                    print("packet not received!")
                # exp_count == 0 never matches, so the loop runs until Ctrl-C.
                if total_count == exp_count:
                    break
                time.sleep(1)
        except KeyboardInterrupt:
            # Ctrl-C just ends the run; the summary is printed below.
            pass
        print_statistics(interface, total_count, recv_count)
    finally:
        # Release the raw sockets on every exit path.
        for s in sockets:
            s.close()

# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
